Application of several XAI techniques of different types/approaches to the binary classification of heart disease (Heart Disease dataset, available at https://archive.ics.uci.edu/ml/datasets/heart+disease).
This notebook uses the following techniques: ELI5 (permutation importance), LIME, Anchors, SHAP, Partial Dependence Plots (PDP), Accumulated Local Effects (ALE) and counterfactual explanations (CounterfactualProto).
The first stage applies basic preprocessing to the dataset and trains the ML models.
The XAI part comes after that.
seed = 42
import pandas as pd
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import MinMaxScaler
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, plot_confusion_matrix
import matplotlib.pyplot as plt
columns = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal', 'target']
Attribute information (abbreviation - description):
- age - age in years
- sex - sex (1 = male; 0 = female)
- cp - chest pain type (1-4)
- trestbps - resting blood pressure (mm Hg)
- chol - serum cholesterol (mg/dl)
- fbs - fasting blood sugar > 120 mg/dl (1 = true; 0 = false)
- restecg - resting electrocardiographic results (0-2)
- thalach - maximum heart rate achieved
- exang - exercise-induced angina (1 = yes; 0 = no)
- oldpeak - ST depression induced by exercise relative to rest
- slope - slope of the peak exercise ST segment (1-3)
- ca - number of major vessels (0-3) colored by fluoroscopy
- thal - 3 = normal; 6 = fixed defect; 7 = reversible defect
- target - diagnosis of heart disease (0 = absence; 1-4 = presence)
data = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data', names=columns, na_values='?')
print(data.shape)
data.head()
(303, 14)
| | age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 63.0 | 1.0 | 1.0 | 145.0 | 233.0 | 1.0 | 2.0 | 150.0 | 0.0 | 2.3 | 3.0 | 0.0 | 6.0 | 0 |
| 1 | 67.0 | 1.0 | 4.0 | 160.0 | 286.0 | 0.0 | 2.0 | 108.0 | 1.0 | 1.5 | 2.0 | 3.0 | 3.0 | 2 |
| 2 | 67.0 | 1.0 | 4.0 | 120.0 | 229.0 | 0.0 | 2.0 | 129.0 | 1.0 | 2.6 | 2.0 | 2.0 | 7.0 | 1 |
| 3 | 37.0 | 1.0 | 3.0 | 130.0 | 250.0 | 0.0 | 0.0 | 187.0 | 0.0 | 3.5 | 3.0 | 0.0 | 3.0 | 0 |
| 4 | 41.0 | 0.0 | 2.0 | 130.0 | 204.0 | 0.0 | 2.0 | 172.0 | 0.0 | 1.4 | 1.0 | 0.0 | 3.0 | 0 |
The target variable contains values from 0 to 4; however, studies with this dataset usually focus on simply distinguishing the presence of disease (values greater than 0) from its absence (0).
data['target'].value_counts().plot(kind='bar', title='Target distribution', figsize=(10, 6), rot=0)
plt.show()
data['target'] = data['target'].apply(lambda x: 1 if x > 0 else 0)
data['target'].value_counts().plot(kind='bar', title='Target distribution after preprocessing', figsize=(10, 6), rot=0)
plt.show()
data['target'].value_counts(normalize=True)
0    0.541254
1    0.458746
Name: target, dtype: float64
data.isnull().sum()
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          4
thal        2
target      0
dtype: int64
fig, axs = plt.subplots(1, 2, figsize=(14,6))
data['thal'].value_counts(dropna=False).plot(kind='bar', rot=0, ax=axs[0], title='thal')
data['ca'].value_counts(dropna=False).plot(kind='bar', rot=0, ax=axs[1], title='ca')
plt.show()
# fill the missing values in ca and thal with the most frequent value
imputer = SimpleImputer(strategy='most_frequent')
data = pd.DataFrame(imputer.fit_transform(data), columns=columns)
The One-Hot Encoding technique is used to encode the categorical variables. For simplicity, and to adjust the column names more easily in a DataFrame, the Pandas get_dummies method was used, but the effect is the same as the technique mentioned.
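For comparison, a minimal sketch of the sklearn OneHotEncoder equivalent of the get_dummies pipeline below (left commented out so the notebook runs unchanged; note that recent sklearn versions rename sparse to sparse_output and get_feature_names to get_feature_names_out):
# Sketch: sklearn equivalent of the get_dummies transformation below
# from sklearn.preprocessing import OneHotEncoder
# ohe = OneHotEncoder(sparse=False, dtype=int)
# encoded = ohe.fit_transform(data[['cp', 'restecg', 'thal']])
# encoded_df = pd.DataFrame(encoded, columns=ohe.get_feature_names(['cp', 'restecg', 'thal']), index=data.index)
# data = pd.concat([encoded_df, data.drop(columns=['cp', 'restecg', 'thal'])], axis=1)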
categorical_cols = ['cp', 'restecg', 'thal']
data[categorical_cols] = data[categorical_cols].apply(lambda x: x.astype(int))
def transform_categorical_cols(col, data):
    df_dummie = pd.get_dummies(data[col])
    df_dummie.columns = [col + '_' + str(int(c)) for c in df_dummie.columns]
    data = pd.concat([df_dummie, data.drop(col, axis=1)], axis=1)
    return data

for col in categorical_cols:
    data = transform_categorical_cols(col, data)
data.head(3)
| | thal_3 | thal_6 | thal_7 | restecg_0 | restecg_1 | restecg_2 | cp_1 | cp_2 | cp_3 | cp_4 | ... | sex | trestbps | chol | fbs | thalach | exang | oldpeak | slope | ca | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | ... | 1.0 | 145.0 | 233.0 | 1.0 | 150.0 | 0.0 | 2.3 | 3.0 | 0.0 | 0.0 |
| 1 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | ... | 1.0 | 160.0 | 286.0 | 0.0 | 108.0 | 1.0 | 1.5 | 2.0 | 3.0 | 1.0 |
| 2 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | ... | 1.0 | 120.0 | 229.0 | 0.0 | 129.0 | 1.0 | 2.6 | 2.0 | 2.0 | 1.0 |
3 rows × 21 columns
data['target'] = data['target'].astype(int)
X_train, X_test, y_train, y_test = train_test_split(data.drop('target', axis=1), data['target'], test_size=0.2, random_state=seed)
feature_range = (0, 1)
scaler = MinMaxScaler(feature_range)
feature_cols = X_train.columns
X_train = scaler.fit_transform(X_train)
X_train = pd.DataFrame(X_train, columns=feature_cols, index=y_train.index)
X_test = scaler.transform(X_test)
X_test = pd.DataFrame(X_test, columns=feature_cols, index=y_test.index)
model = RandomForestClassifier(random_state=seed)
scores = cross_val_score(estimator=model, X=X_train, y=y_train, cv=5, n_jobs=-1)
# cross_val_score(estimator=model, X=X_train, y=y_train, cv=5, n_jobs=-1, scoring='roc_auc').mean()
print('Mean accuracy:', np.mean(scores))
print('Standard deviation:', np.std(scores))
print(scores)
Mean accuracy: 0.8098639455782312
Standard deviation: 0.05324881261743182
[0.7755102  0.85714286 0.72916667 0.8125     0.875     ]
model = RandomForestClassifier(random_state=seed)
model.fit(X_train, y_train)
RandomForestClassifier(random_state=42)
pred = model.predict(X_test)
pred_scores = model.predict_proba(X_test)[:, 1]
print('Acc score:', accuracy_score(y_test, pred))
print('F1 score:', f1_score(y_test, pred))
print('AUC ROC:', roc_auc_score(y_test, pred_scores))
plot_confusion_matrix(model, X_test, y_test, normalize='true')
Acc score: 0.8524590163934426
F1 score: 0.8524590163934426
AUC ROC: 0.9401939655172413
import tensorflow as tf
tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input
from tensorflow.keras.models import Model, load_model, Sequential
from tensorflow.keras.utils import to_categorical
print('TF version: ', tf.__version__)
print('Eager execution enabled: ', tf.executing_eagerly()) # False
TF version:  2.4.1
Eager execution enabled:  False
def create_nn_softmax(input_dim):
    model = Sequential()
    model.add(Input(shape=(input_dim,)))
    model.add(Dense(16, activation='relu'))
    model.add(Dense(16, activation='relu'))
    model.add(Dense(2, activation='softmax'))
    return model
y_train_b = to_categorical(y_train)
y_test_b = to_categorical(y_test)
nn = create_nn_softmax(X_train.shape[1])
nn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
nn.summary()
nn.fit(X_train, y_train_b, batch_size=128, epochs=50, verbose=1, shuffle=False, validation_split=0.1)
score = nn.evaluate(X_train, y_train_b, verbose=0)
print('Train accuracy:', score[1])
score = nn.evaluate(X_test, y_test_b, verbose=0)
print('Test accuracy:', score[1])
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 16) 336 _________________________________________________________________ dense_1 (Dense) (None, 16) 272 _________________________________________________________________ dense_2 (Dense) (None, 2) 34 ================================================================= Total params: 642 Trainable params: 642 Non-trainable params: 0 _________________________________________________________________ Train on 217 samples, validate on 25 samples Epoch 1/50 217/217 [==============================] - 0s 307us/sample - loss: 0.6620 - acc: 0.6175 - val_loss: 0.6459 - val_acc: 0.6400 Epoch 2/50 217/217 [==============================] - 0s 13us/sample - loss: 0.6536 - acc: 0.6359 - val_loss: 0.6359 - val_acc: 0.6400 Epoch 3/50 217/217 [==============================] - 0s 13us/sample - loss: 0.6458 - acc: 0.6498 - val_loss: 0.6261 - val_acc: 0.6400 Epoch 4/50 217/217 [==============================] - 0s 13us/sample - loss: 0.6383 - acc: 0.6728 - val_loss: 0.6168 - val_acc: 0.6400 Epoch 5/50 217/217 [==============================] - 0s 12us/sample - loss: 0.6312 - acc: 0.6820 - val_loss: 0.6077 - val_acc: 0.6800 Epoch 6/50 217/217 [==============================] - 0s 19us/sample - loss: 0.6241 - acc: 0.6912 - val_loss: 0.5990 - val_acc: 0.6800 Epoch 7/50 217/217 [==============================] - 0s 14us/sample - loss: 0.6173 - acc: 0.6959 - val_loss: 0.5906 - val_acc: 0.7200 Epoch 8/50 217/217 [==============================] - 0s 15us/sample - loss: 0.6105 - acc: 0.7051 - val_loss: 0.5824 - val_acc: 0.7200 Epoch 9/50 217/217 [==============================] - 0s 14us/sample - loss: 0.6039 - acc: 0.7051 - val_loss: 0.5745 - val_acc: 0.7200 Epoch 10/50 217/217 [==============================] - 0s 13us/sample - loss: 0.5974 - acc: 0.7005 - val_loss: 0.5669 - val_acc: 0.7200 Epoch 11/50 217/217 [==============================] - 0s 15us/sample - loss: 0.5910 - acc: 0.7051 - val_loss: 0.5592 - val_acc: 0.7200 Epoch 12/50 217/217 [==============================] - 0s 16us/sample - loss: 0.5848 - acc: 0.7051 - val_loss: 0.5520 - val_acc: 0.7200 Epoch 13/50 217/217 [==============================] - 0s 14us/sample - loss: 0.5787 - acc: 0.7051 - val_loss: 0.5450 - val_acc: 0.7200 Epoch 14/50 217/217 [==============================] - 0s 14us/sample - loss: 0.5727 - acc: 0.7189 - val_loss: 0.5383 - val_acc: 0.7200 Epoch 15/50 217/217 [==============================] - 0s 13us/sample - loss: 0.5668 - acc: 0.7281 - val_loss: 0.5316 - val_acc: 0.7200 Epoch 16/50 217/217 [==============================] - 0s 13us/sample - loss: 0.5611 - acc: 0.7281 - val_loss: 0.5250 - val_acc: 0.7200 Epoch 17/50 217/217 [==============================] - 0s 16us/sample - loss: 0.5555 - acc: 0.7281 - val_loss: 0.5186 - val_acc: 0.7600 Epoch 18/50 217/217 [==============================] - 0s 15us/sample - loss: 0.5500 - acc: 0.7235 - val_loss: 0.5123 - val_acc: 0.7600 Epoch 19/50 217/217 [==============================] - 0s 15us/sample - loss: 0.5446 - acc: 0.7327 - val_loss: 0.5059 - val_acc: 0.8000 Epoch 20/50 217/217 [==============================] - 0s 13us/sample - loss: 0.5393 - acc: 0.7327 - val_loss: 0.4996 - val_acc: 0.8000 Epoch 21/50 217/217 [==============================] - 0s 14us/sample - loss: 0.5341 - acc: 0.7373 - val_loss: 0.4933 - val_acc: 0.8000 Epoch 22/50 217/217 [==============================] - 0s 
13us/sample - loss: 0.5291 - acc: 0.7373 - val_loss: 0.4871 - val_acc: 0.8000 Epoch 23/50 217/217 [==============================] - 0s 18us/sample - loss: 0.5241 - acc: 0.7373 - val_loss: 0.4810 - val_acc: 0.8000 Epoch 24/50 217/217 [==============================] - 0s 13us/sample - loss: 0.5193 - acc: 0.7419 - val_loss: 0.4751 - val_acc: 0.8000 Epoch 25/50 217/217 [==============================] - 0s 18us/sample - loss: 0.5146 - acc: 0.7419 - val_loss: 0.4695 - val_acc: 0.8000 Epoch 26/50 217/217 [==============================] - 0s 16us/sample - loss: 0.5100 - acc: 0.7419 - val_loss: 0.4642 - val_acc: 0.8000 Epoch 27/50 217/217 [==============================] - 0s 14us/sample - loss: 0.5056 - acc: 0.7419 - val_loss: 0.4591 - val_acc: 0.8000 Epoch 28/50 217/217 [==============================] - 0s 16us/sample - loss: 0.5012 - acc: 0.7465 - val_loss: 0.4542 - val_acc: 0.8000 Epoch 29/50 217/217 [==============================] - 0s 16us/sample - loss: 0.4970 - acc: 0.7465 - val_loss: 0.4495 - val_acc: 0.8000 Epoch 30/50 217/217 [==============================] - 0s 13us/sample - loss: 0.4929 - acc: 0.7512 - val_loss: 0.4452 - val_acc: 0.8000 Epoch 31/50 217/217 [==============================] - 0s 13us/sample - loss: 0.4890 - acc: 0.7512 - val_loss: 0.4411 - val_acc: 0.8000 Epoch 32/50 217/217 [==============================] - 0s 15us/sample - loss: 0.4853 - acc: 0.7512 - val_loss: 0.4373 - val_acc: 0.8000 Epoch 33/50 217/217 [==============================] - 0s 14us/sample - loss: 0.4817 - acc: 0.7558 - val_loss: 0.4336 - val_acc: 0.8000 Epoch 34/50 217/217 [==============================] - 0s 15us/sample - loss: 0.4783 - acc: 0.7650 - val_loss: 0.4301 - val_acc: 0.8000 Epoch 35/50 217/217 [==============================] - 0s 17us/sample - loss: 0.4749 - acc: 0.7696 - val_loss: 0.4268 - val_acc: 0.8000 Epoch 36/50 217/217 [==============================] - 0s 16us/sample - loss: 0.4717 - acc: 0.7696 - val_loss: 0.4235 - val_acc: 0.8000 Epoch 37/50 217/217 [==============================] - 0s 14us/sample - loss: 0.4687 - acc: 0.7696 - val_loss: 0.4203 - val_acc: 0.8000 Epoch 38/50 217/217 [==============================] - 0s 14us/sample - loss: 0.4657 - acc: 0.7696 - val_loss: 0.4174 - val_acc: 0.8000 Epoch 39/50 217/217 [==============================] - 0s 12us/sample - loss: 0.4629 - acc: 0.7742 - val_loss: 0.4147 - val_acc: 0.8000 Epoch 40/50 217/217 [==============================] - 0s 16us/sample - loss: 0.4602 - acc: 0.7834 - val_loss: 0.4121 - val_acc: 0.8400 Epoch 41/50 217/217 [==============================] - 0s 13us/sample - loss: 0.4577 - acc: 0.7926 - val_loss: 0.4097 - val_acc: 0.8400 Epoch 42/50 217/217 [==============================] - 0s 13us/sample - loss: 0.4552 - acc: 0.7926 - val_loss: 0.4075 - val_acc: 0.8400 Epoch 43/50 217/217 [==============================] - 0s 12us/sample - loss: 0.4528 - acc: 0.7926 - val_loss: 0.4054 - val_acc: 0.8400
/home/milton/anaconda3/envs/alibi/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2325: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
warnings.warn('`Model.state_updates` will be removed in a future version. '
Epoch 44/50 217/217 [==============================] - 0s 14us/sample - loss: 0.4505 - acc: 0.7926 - val_loss: 0.4036 - val_acc: 0.8400 Epoch 45/50 217/217 [==============================] - 0s 14us/sample - loss: 0.4483 - acc: 0.7926 - val_loss: 0.4020 - val_acc: 0.8400 Epoch 46/50 217/217 [==============================] - 0s 13us/sample - loss: 0.4462 - acc: 0.7926 - val_loss: 0.4006 - val_acc: 0.8400 Epoch 47/50 217/217 [==============================] - 0s 17us/sample - loss: 0.4441 - acc: 0.7880 - val_loss: 0.3992 - val_acc: 0.8400 Epoch 48/50 217/217 [==============================] - 0s 13us/sample - loss: 0.4421 - acc: 0.7880 - val_loss: 0.3978 - val_acc: 0.8800 Epoch 49/50 217/217 [==============================] - 0s 13us/sample - loss: 0.4401 - acc: 0.7880 - val_loss: 0.3965 - val_acc: 0.9200 Epoch 50/50 217/217 [==============================] - 0s 14us/sample - loss: 0.4382 - acc: 0.7926 - val_loss: 0.3952 - val_acc: 0.9200 Train accuracy: 0.80991733 Test accuracy: 0.91803277
class_names = ['Normal', 'Disease']
idx = 1
X_idx = X_test.iloc[[idx]]
Provides feature importances for any black-box model. It reads the information directly from different model types and, to guarantee the functionality for arbitrary models, uses the permutation importance method, modifying features and checking how the result changes. Available at https://github.com/TeamHG-Memex/eli5.
import eli5
from eli5.sklearn import PermutationImportance
eli5.show_weights(model, feature_names=X_train.columns.tolist())
| Weight | Feature |
|---|---|
| 0.1079 ± 0.1597 | ca |
| 0.1044 ± 0.1357 | oldpeak |
| 0.0986 ± 0.1384 | thalach |
| 0.0887 ± 0.1658 | cp_4 |
| 0.0871 ± 0.0952 | age |
| 0.0811 ± 0.0760 | chol |
| 0.0793 ± 0.1905 | thal_3 |
| 0.0710 ± 0.0759 | trestbps |
| 0.0594 ± 0.1569 | thal_7 |
| 0.0503 ± 0.1055 | slope |
| 0.0442 ± 0.0948 | exang |
| 0.0344 ± 0.0578 | sex |
| 0.0211 ± 0.0502 | cp_3 |
| 0.0166 ± 0.0401 | restecg_0 |
| 0.0164 ± 0.0400 | cp_1 |
| 0.0159 ± 0.0341 | restecg_2 |
| 0.0090 ± 0.0316 | cp_2 |
| 0.0088 ± 0.0235 | fbs |
| 0.0057 ± 0.0200 | thal_6 |
| 0.0001 ± 0.0010 | restecg_1 |
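Note that for a fitted RandomForestClassifier, eli5.show_weights displays the model's impurity-based (Gini) feature importances, with the ± values reflecting the variation across the trees in the ensemble; the permutation-based alternative is shown commented out below, followed by a hand-rolled sketch of the same idea.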
## Option with PermutationImportance
# perm = PermutationImportance(model, n_iter=5).fit(X_train, y_train)
# eli5.show_weights(perm, feature_names=X_train.columns.tolist())
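To make the idea behind permutation importance concrete, a minimal hand-rolled sketch (a simplified illustration, not the eli5 implementation): shuffle one column at a time and measure the drop in accuracy.
def manual_permutation_importance(model, X, y, n_iter=5, random_state=seed):
    # baseline score on the intact data
    rng = np.random.RandomState(random_state)
    baseline = accuracy_score(y, model.predict(X))
    importances = {}
    for col in X.columns:
        drops = []
        for _ in range(n_iter):
            X_perm = X.copy()
            X_perm[col] = rng.permutation(X_perm[col].values)  # break the column's relationship with y
            drops.append(baseline - accuracy_score(y, model.predict(X_perm)))
        importances[col] = np.mean(drops)
    return pd.Series(importances).sort_values(ascending=False)

# manual_permutation_importance(model, X_test, y_test).head(10)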
Surrogate model that acts locally, in the neighborhood of an instance. A linear model is able to capture the black-box model's behavior in that region. Available at https://github.com/marcotcr/lime.
import lime
import lime.lime_tabular
explainer = lime.lime_tabular.LimeTabularExplainer(X_train.values, training_labels=y_train, class_names=class_names,
                                                   feature_names=X_train.columns, kernel_width=3, discretize_continuous=True, verbose=False)
exp = explainer.explain_instance(X_idx.values[0], model.predict_proba, num_features=15)
exp.show_in_notebook()
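The essence of what LIME does above can be sketched by hand (heavily simplified; the library also discretizes features and samples in a binary representation): perturb the instance, weight the samples by proximity, and fit a weighted linear model.
from sklearn.linear_model import Ridge

x0 = X_idx.values[0]
rng_lime = np.random.RandomState(seed)
perturbed = x0 + rng_lime.normal(0, 0.1, size=(500, x0.shape[0]))                 # samples around the instance
proximity = np.exp(-np.linalg.norm(perturbed - x0, axis=1) ** 2 / (2 * 3 ** 2))   # RBF kernel, width 3
local_model = Ridge().fit(perturbed, model.predict_proba(perturbed)[:, 1], sample_weight=proximity)
coefs = pd.Series(local_model.coef_, index=X_train.columns)
print(coefs[coefs.abs().sort_values(ascending=False).index].head())               # most influential features locally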
Developed by the same creators as LIME, the algorithm extends it by generating "anchors" for the predictions: if-then rules that capture the key feature conditions responsible for a given output. Available at https://github.com/marcotcr/anchor.
from anchor import utils
from anchor import anchor_tabular
explainer = anchor_tabular.AnchorTabularExplainer(
    class_names,
    X_train.columns.tolist(),
    X_train.values)
print('Prediction: ', explainer.class_names[model.predict(X_idx)[0]])
exp = explainer.explain_instance(X_idx.iloc[0].values, model.predict, threshold=0.90)
Prediction: Disease
print('Anchor: %s' % (' AND '.join(exp.names())))
print('Precision: %.2f' % exp.precision())
print('Coverage: %.2f' % exp.coverage())
Anchor: ca > 0.00 AND cp_4 > 0.00 AND exang > 0.00
Precision: 0.93
Coverage: 0.16
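Precision is the fraction of instances that satisfy the anchor and receive the same prediction as the explained instance; coverage is the fraction of the data to which the rule applies. Here, whenever ca > 0, cp_4 = 1 and exang = 1 (in the scaled feature space), the model predicts Disease 93% of the time, and the rule covers 16% of the data.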
exp.show_in_notebook()
The goal of the SHAP algorithm is to explain the prediction for an instance by computing the contribution of each feature. SHAP is strongly grounded in game theory, through the Shapley values technique, and provides both local and global interpretations. The author's implementation is available at https://github.com/slundberg/shap.
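At its core, the Shapley value of a feature $i$ is its marginal contribution to the prediction, averaged over all subsets $S$ of the remaining features $F \setminus \{i\}$:

$\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|! \, (|F| - |S| - 1)!}{|F|!} \left[ f(S \cup \{i\}) - f(S) \right]$

TreeExplainer, used below, computes these values exactly and in polynomial time for tree ensembles.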
import shap
shap.initjs()
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values[1], X_test, max_display=15)
shap.force_plot(explainer.expected_value[1], shap_values[1][idx,:], X_idx)
A visualization technique for explanations that shows the marginal effect of one or two variables on the ML model's predictions, revealing the relationship between features and predictions. Experiments use the PDPbox library (https://github.com/SauceCat/PDPbox).
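Conceptually, the partial dependence of one feature is the model's average prediction when that feature is forced to each value of a grid. A minimal sketch of the computation (simplified; PDPbox adds quantile grids, ICE lines, etc.):
def manual_pdp(model, X, feature, grid_points=10):
    # force the feature to each grid value for every row and average the predictions
    grid = np.linspace(X[feature].min(), X[feature].max(), grid_points)
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[feature] = value
        averages.append(model.predict_proba(X_mod)[:, 1].mean())
    return grid, np.array(averages)

# grid, pd_vals = manual_pdp(model, X_test, 'ca')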
from pdpbox import pdp, get_dataset, info_plots
pdp_goals = pdp.pdp_isolate(model=model, dataset=X_test, model_features=X_test.columns, feature='ca')
pdp.pdp_plot(pdp_goals, 'ca')
plt.show()
pdp_goals = pdp.pdp_isolate(model=model, dataset=X_test, model_features=X_test.columns, feature='thalach')
fig, axes = pdp.pdp_plot(
    pdp_goals, 'thalach', frac_to_plot=0.5, plot_lines=True, x_quantile=False, show_percentile=True, plot_pts_dist=True
)
fig, axes, summary_df = info_plots.target_plot(
    # X_test and y_test share the same index, so they can be concatenated directly
    df=pd.concat([X_test, y_test], axis=1), feature='sex', feature_name='sex', target='target'
)
fig, axes, summary_df = info_plots.actual_plot(
    model=model, X=X_test, feature='sex', feature_name='sex', predict_kwds={}
)
fig, axes, summary_df = info_plots.target_plot(
    df=pd.concat([X_test, y_test], axis=1), feature=['thal_3', 'thal_6', 'thal_7'], feature_name='Thal', target='target'
)
inter1 = pdp.pdp_interact(
    model=model, dataset=X_test, model_features=X_test.columns, features=['ca', 'oldpeak']
)
fig, axes = pdp.pdp_interact_plot(
    pdp_interact_out=inter1, feature_names=['ca', 'oldpeak'], plot_type='contour', x_quantile=True, plot_pdp=True
)
Describes how the features influence the model's predictions on average. It is similar to and inspired by PDP, but rests on different mathematics: it averages accumulated local differences in the predictions, so it does not assume that the features are independent and uncorrelated, as PDP does. The experiments use the Alibi library (https://github.com/SeldonIO/alibi).
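A simplified sketch of the ALE idea for one numeric feature (the library also handles bin weighting, categorical features and confidence intervals): bin the feature, measure the prediction change across each bin's edges for the points inside it, then accumulate and center the effects.
def manual_ale(model, X, feature, n_bins=10):
    edges = np.quantile(X[feature], np.linspace(0, 1, n_bins + 1))
    effects = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = X[(X[feature] >= lo) & (X[feature] <= hi)]
        if len(in_bin) == 0:
            effects.append(0.0)
            continue
        X_lo, X_hi = in_bin.copy(), in_bin.copy()
        X_lo[feature] = lo
        X_hi[feature] = hi
        # local effect: prediction difference across the bin, for points inside the bin
        effects.append((model.predict_proba(X_hi)[:, 1] - model.predict_proba(X_lo)[:, 1]).mean())
    ale = np.cumsum(effects)
    return edges[1:], ale - ale.mean()  # centered accumulated effects

# edges, ale_vals = manual_ale(model, X_test, 'thalach')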
from alibi.explainers import ALE, plot_ale
ale = ALE(model.predict_proba, feature_names=X_test.columns.tolist(), target_names=class_names)
exp = ale.explain(X_test.values)
plot_ale(exp, features=['ca', 'oldpeak'], fig_kw={'figwidth': 12, 'figheight': 5}, sharey='all')
plt.show()
The algorithm aims to identify the minimal set of changes to an instance's features that produces a predefined prediction, in this case the opposite class. This experiment uses a variation of the original Counterfactual algorithm called CounterfactualProto, which combines the original method with prototypes, reaching a result faster and also handling categorical variables well. The Alibi implementation was used (https://github.com/SeldonIO/alibi).
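As a rough guide to the hyperparameters used below, the method (Van Looveren & Klaise, 2019) finds the counterfactual $x_{cf}$ by minimizing a loss of approximately the form

$L = c \cdot L_{pred} + \beta \, \|x_{cf} - x\|_1 + \|x_{cf} - x\|_2^2 + L_{AE} + L_{proto}$

where $L_{pred}$ pushes the prediction toward a different class, the elastic-net terms (weighted by the beta argument below) keep the change small and sparse, and $L_{proto}$ pulls the counterfactual toward a prototype of the target class, which speeds up the search.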
from alibi.explainers import CounterFactual, CounterFactualProto
from alibi.utils.mapping import ohe_to_ord_shape, ohe_to_ord, ord_to_ohe
rng = feature_range  # (0, 1), the range used by the MinMaxScaler
rng_shape = (1, 13)  # 13 features in the ordinal representation: 10 numeric + 3 categorical
feature_range_numeric = ((np.ones(rng_shape) * rng[0]).astype(np.float32),
                         (np.ones(rng_shape) * rng[1]).astype(np.float32))
# start column of each one-hot group -> number of categories: thal (3), restecg (3), cp (4)
cat_cols_dict = {0: 3, 3: 3, 6: 4}
cf = CounterFactualProto(nn,
                         X_idx.shape,
                         beta=.01,
                         cat_vars=cat_cols_dict,
                         ohe=True,
                         max_iterations=1000,
                         feature_range=feature_range_numeric,
                         c_init=1.,
                         c_steps=5
                         )
cf.fit(X_train.values)
explanation = cf.explain(X_idx.values)
pred_class_original = explanation['orig_class']
proba_original = explanation['orig_proba'][0][pred_class_original]
print(f'Original prediction: {pred_class_original} with probability {proba_original}')
pred_class_cf = explanation.cf['class']
proba_cf = explanation.cf['proba'][0][pred_class_cf]
print(f'Counterfactual prediction: {pred_class_cf} with probability {proba_cf}')
Original prediction: 1 with probability 0.6100563406944275
Counterfactual prediction: 0 with probability 0.5155072212219238
X_cf = explanation['cf']['X']
def create_dataframe_instances(X_idx, X_new, name):
    # predicted classes for the original instance, the new instance, and a filler for the delta column
    # (nn.predict returns the softmax probabilities)
    classes = [class_names[np.argmax(nn.predict(X_idx))],
               class_names[np.argmax(nn.predict(X_new))],
               'NIL']
    delta = np.around((X_new - X_idx.values).astype(np.double), 2)
    delta[np.absolute(delta) < 1e-4] = 0
    X_idx_df = X_idx.iloc[0]
    X_new_df = pd.Series(X_new[0], index=X_train.columns)
    delta_df = pd.Series(delta[0], index=X_train.columns)
    X_idx_df.name = 'X'
    X_new_df.name = name
    delta_df.name = name + ' - X'
    dfre = pd.concat([X_idx_df, X_new_df, delta_df], axis=1)
    dfre.loc['Class', :] = classes
    return dfre
def highlight_ce(s, col, ncols):
    # highlight increased feature values in yellow and decreased ones in red
    if type(s[col]) != str:
        if s[col] > 0:
            return ['background-color: yellow'] * ncols
        if s[col] < 0:
            return ['background-color: red'] * ncols
    return ['background-color: white'] * ncols
df_result = create_dataframe_instances(X_idx, X_cf, 'CF')
df_result.style.apply(highlight_ce, col='CF - X', ncols=3, axis=1)
| | X | CF | CF - X |
|---|---|---|---|
| thal_3 | 1.000000 | 1.000000 | 0.000000 |
| thal_6 | 0.000000 | 0.000000 | 0.000000 |
| thal_7 | 0.000000 | 0.000000 | 0.000000 |
| restecg_0 | 0.000000 | 1.000000 | 1.000000 |
| restecg_1 | 0.000000 | 0.000000 | 0.000000 |
| restecg_2 | 1.000000 | 0.000000 | -1.000000 |
| cp_1 | 0.000000 | 0.000000 | 0.000000 |
| cp_2 | 0.000000 | 0.000000 | 0.000000 |
| cp_3 | 0.000000 | 0.000000 | 0.000000 |
| cp_4 | 1.000000 | 1.000000 | 0.000000 |
| age | 0.520833 | 0.559805 | 0.040000 |
| sex | 1.000000 | 1.000000 | 0.000000 |
| trestbps | 0.150943 | 0.131322 | -0.020000 |
| chol | 0.274914 | 0.321727 | 0.050000 |
| fbs | 0.000000 | 0.000000 | 0.000000 |
| thalach | 0.282443 | 0.242497 | -0.040000 |
| exang | 1.000000 | 0.966759 | -0.030000 |
| oldpeak | 0.000000 | 0.000000 | 0.000000 |
| slope | 0.500000 | 0.393760 | -0.110000 |
| ca | 0.333333 | 0.298004 | -0.040000 |
| Class | Disease | Normal | NIL |
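In this example, the counterfactual mainly flips restecg from category 2 to category 0, with small decreases in slope, ca, thalach, trestbps and exang (and small increases in age and chol), which is enough to move the prediction from Disease to Normal.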